
Add examples for MoE models - Mixtral in TE #2642

Open

faradawn wants to merge 57 commits into NVIDIA:main from faradawn:add-moe-example

Conversation

@faradawn faradawn (Contributor) commented Feb 2, 2026

Summary

This PR adds a complete tutorial for integrating HuggingFace Mixtral (MoE) with Transformer Engine, addressing the gap identified in #2573.

What's included

  • te_mixtral.py — Drop-in TEMixtralSparseMoeBlock that replaces HF's loop-over-experts with TE's GroupedLinear (batched GEMM) + moe_permute/moe_unpermute. Includes replace_moe_block context manager, TEMixtralForCausalLM with HF weight loading, and replace_params for expert weight packing.
  • utils.py — Data loading, BF16/FP8 model init, Accelerate wrapping, fine-tuning loop — mirrors te_llama/utils.py style.
  • requirements.txt — Pinned dependencies matching the Llama/Gemma tutorials.
  • Tutorial notebook — Full tutorial matching the quality bar of te_llama and te_gemma, covering:
    1. Architecture overview: Transformer → Mixtral MoE, HF bottleneck, TE approach
    2. Unit-test cell verifying output shape/dtype against the HF block
    3. [Baseline] HF Mixtral in BF16
    4. [Improvement 1] TE GroupedLinear MoE in BF16
    5. [Improvement 2] TE GroupedLinear MoE in FP8
    6. Expert routing considerations with mixed precision (m_splits, per-expert FP8 scaling, aux loss passthrough)
    7. Generalisation guide for other MoE architectures (DeepSeek, Grok-1, etc.)

Bug fix

Corrected the m_splits calculation flagged by the automated review:

# Before (wrong): double-counts tokens by reducing with .any() then multiplying by top_k
expert_mask = (selected_experts == expert_idx).any(dim=-1)
m_splits.append(expert_mask.sum().item() * self.top_k)

# After (correct): count the actual number of (token, top_k_slot) pairs per expert
m_splits = [(selected_experts == i).sum().item() for i in range(self.num_experts)]
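
For context, a minimal sketch of the path these per-expert counts feed into (permute, grouped GEMM, unpermute) is shown below. This is an illustration of the approach, not the code in te_mixtral.py: the expert FFN is simplified to a single SiLU projection rather than Mixtral's gated SwiGLU, and the moe_permute/moe_unpermute argument names follow this PR's discussion, so they may differ across TE versions.

# Illustrative sketch only: simplified grouped-GEMM MoE forward, with m_splits
# computed per (token, top_k slot) pair as in the corrected code above.
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch import moe_permute, moe_unpermute

def grouped_moe_forward(hidden, router_logits, fc1: te.GroupedLinear,
                        fc2: te.GroupedLinear, num_experts: int, top_k: int):
    # hidden: [num_tokens, hidden_size]; pick top_k experts per token
    probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    routing_weights, selected_experts = torch.topk(probs, top_k, dim=-1)
    # Mixtral renormalizes the selected top_k weights
    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

    # One count per (token, top_k slot) pair; a single GPU sync via bincount
    m_splits = torch.bincount(selected_experts.flatten(),
                              minlength=num_experts).tolist()

    # Group tokens by expert; -1 means "no token dropping"
    permuted, row_id_map = moe_permute(hidden, selected_experts,
                                       num_out_tokens=-1, map_type="index")

    # Batched GEMMs over all experts replace HF's Python loop-over-experts
    expert_out = fc2(torch.nn.functional.silu(fc1(permuted, m_splits)), m_splits)

    # Restore the original token order, merging by routing weights (kept in fp32)
    return moe_unpermute(expert_out, row_id_map,
                         merging_probs=routing_weights, map_type="index")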

Scope

Covers all topics requested in #2573:

  • How to wrap MoE layers with TE modules ✓
  • FP8 training configuration for MoE ✓
  • Expert routing considerations with mixed precision ✓
  • Generalisation to arbitrary MoE architectures ✓

@greptile-apps greptile-apps Bot (Contributor) commented Feb 2, 2026

Greptile Summary

This PR adds a comprehensive tutorial for integrating HuggingFace Mixtral (MoE) with Transformer Engine, covering TE GroupedLinear, expert parallelism via NCCL all-to-all, quantized recipes, and a fused tier-9 path (te_mixtral_mxfp8.py + te_all2all.py). Multiple issues raised in prior review rounds have been corrected (class renames, m_splits calculation, map_type="index", merging_probs rename, warmup-steps hardcoding, missing collator dependency, wandb removal).

  • te_mixtral.py loop-mode EP crash: In _expert_ffn loop mode, gate_up_w.to_local() is called on an nn.Parameter that wraps a DTensor after set_ep_group(); nn.Parameter has no to_local(), so every EP forward pass in this mode raises AttributeError.
  • Unresolved prior issues: Auto-regressive decode assertion (should_pack_inputs), input_ids crash when inputs_embeds is passed, and a non-standard num_out_tokens sentinel in moe_permute remain open in both te_mixtral.py and te_mixtral_mxfp8.py.
  • Missing flash-attn in requirements.txt: Both baseline and TE model loaders set _attn_implementation="flash_attention_2", causing an import error for users who install only what the requirements file lists.

Confidence Score: 2/5

Several forward-pass crashes remain unresolved across both model files; the PR is not ready to merge in its current state.

The loop-mode EP path in te_mixtral.py crashes on every forward call when ep_size > 1 due to calling a nonexistent method on nn.Parameter. This compounds previously flagged but still-open bugs: the should_pack_inputs assertion fires on every decode step, input_ids.shape crashes when callers pass inputs_embeds inside an InferenceParams branch, and flash-attn is required but absent from the dependency file.

te_mixtral.py (loop-mode EP forward crash + decode assertion + inputs_embeds guard) and te_mixtral_mxfp8.py (decode assertion + inputs_embeds guard) need the most attention before merge.

Important Files Changed

  • docs/examples/te_mixtral/te_mixtral.py: Core TE Mixtral model (1641 lines). Contains a crash in _expert_ffn loop mode when EP > 1 due to calling .to_local() on nn.Parameter instead of its .data DTensor. Several other known crashes (inputs_embeds + InferenceParams, decode-step assertion) remain unresolved per earlier review threads.
  • docs/examples/te_mixtral/te_mixtral_mxfp8.py: Tier-9 MXFP8 Mixtral model. Still carries the input_ids.shape crash when inputs_embeds is supplied and the should_pack_inputs assertion crash during auto-regressive decode (flagged in earlier threads and not yet fixed).
  • docs/examples/te_mixtral/utils.py: Training utilities. Warmup-steps and wandb issues addressed; NVMixtralForCausalLM import fixed; strict=False load_state_dict swallowing weight-mapping failures still present (flagged earlier).
  • docs/examples/te_mixtral/te_all2all.py: NCCL all-to-all dispatcher for tier-9 MXFP8. Clean implementation; correctly omits merging_probs in combine() since ScaledSwiGLU handles it.
  • docs/examples/te_mixtral/run_finetune_ep.py: Multi-tier fine-tuning launcher. Tier labels and improvement indices are now internally consistent.
  • docs/examples/te_mixtral/requirements.txt: Missing flash-attn dependency (flagged in earlier thread); wandb usage removed from code so that omission is now moot.
  • docs/examples/te_mixtral/collator.py: DataCollatorWithFlattening now shipped alongside utils.py; resolves the bionemo_mixtral import error flagged and marked fixed in prior review.
  • docs/examples/te_mixtral/test_accuracy.py: Correct weight-mapping parity test; captures the load_state_dict return value and asserts that only _extra_state keys are missing.

Reviews (29): Last reviewed commit: "Merge branch 'add-moe-example' of github..."

@greptile-apps greptile-apps Bot (Contributor) left a comment

1 file reviewed, 2 comments

Comment thread docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb Outdated
Comment thread docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb Outdated
@sudhakarsingh27 (Collaborator) commented:

Thanks, @faradawn! Also adding @sbhavani to the discussion. Compared to the other llama/gemma tutorials, this one seems quite barebones and looks more like a code example than a tutorial. @sbhavani, do you think that in its current form it covers the scope you requested in #2573?

Comment thread docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb Outdated
Comment thread docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb Outdated
Comment thread docs/examples/te_mixtral/te_mixtral.py Outdated
Comment thread docs/examples/te_mixtral/te_mixtral.py Outdated
Comment thread docs/examples/te_mixtral/utils.py
@faradawn faradawn (Contributor, Author) commented Apr 2, 2026

Hi @sudhakarsingh27, can you check if this addresses your comments? Tested on 2x H100.

@sbhavani sbhavani (Collaborator) commented Apr 6, 2026

> Thanks, @faradawn! Also adding @sbhavani to the discussion. Compared to the other llama/gemma tutorials, this one seems quite barebones and looks more like a code example than a tutorial. @sbhavani, do you think that in its current form it covers the scope you requested in #2573?

Agreed! I think any example should show some perf gain and include the whole weight mapping so the user can run the example.

@pggPL pggPL self-assigned this Apr 13, 2026
@pggPL pggPL (Collaborator) commented Apr 13, 2026

The documentation build is not working; if you fix it, please ping me and I'll review.

Comment thread docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb Outdated
Comment thread docs/examples/te_mixtral/utils.py Outdated
Comment thread docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb Outdated
Comment thread docs/examples/te_mixtral/utils.py Outdated
Comment thread docs/examples/te_mixtral/HANDOFF.md Outdated
Comment thread docs/examples/te_mixtral/te_mixtral.py Outdated
faradawn and others added 11 commits April 21, 2026 11:29
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
- Fix m_splits bug: use (selected_experts == i).sum() instead of
  .any(dim=-1).sum() * top_k, which caused dimension mismatches in
  GroupedLinear
- Add te_mixtral.py: TEMixtralSparseMoeBlock (GroupedLinear + moe_permute/
  unpermute), replace_moe_block context manager, TEMixtralForCausalLM with
  HF weight loading, and replace_params for expert weight packing
- Add utils.py: HyperParameters, data loading, BF16/FP8 model init,
  Accelerate wrapping, fine-tuning loop — mirrors te_llama/utils.py style
- Add requirements.txt matching te_llama versions
- Expand notebook from bare code snippet to full tutorial covering:
  architecture overview, HF vs TE comparison table, unit-test cell,
  baseline BF16 run, BF16 TE improvement, FP8 TE improvement, expert
  routing/scaling discussion, generalisation guide for other MoE models

Addresses reviewer feedback: fixes the critical runtime bug (greptile)
and expands to tutorial quality comparable to the Llama/Gemma examples
(sudhakarsingh27), covering the scope requested in issue NVIDIA#2573.

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
for more information, see https://pre-commit.ci

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
A few bugs found during code review:

- moe_permute and moe_unpermute were missing map_type='index'. The
  selected_experts tensor is [num_tokens, top_k] indices, not a mask,
  so without this the routing is completely wrong at runtime.
- num_out_tokens=None should be -1 (the API expects an int).
- moe_unpermute: replaced deprecated probs= with merging_probs= and
  kept routing_weights in float32 as TE recommends.
- utils.py: num_warmup_steps was hardcoded to 100 instead of using
  hyperparams.num_warmup_steps, which made the benchmark meaningless.
- requirements.txt: transformers==4.57.0 doesn't exist, fixed to 4.47.1.
- Notebook generalisation guide: updated code template with the same fixes.

Tested on 2xH100 in nvcr.io/nvidia/pytorch:26.01-py3 (PyTorch 2.10.0, CUDA 13.1):

  $ python3 -c "
  from te_mixtral import TEMixtralSparseMoeBlock
  from transformers import MixtralConfig
  from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
  import torch

  cfg = MixtralConfig(hidden_size=256, intermediate_size=512,
                      num_local_experts=4, num_experts_per_tok=2)
  x = torch.randn(2, 8, cfg.hidden_size, device='cuda', dtype=torch.bfloat16)
  hf_out, hf_logits = MixtralSparseMoeBlock(cfg).cuda().bfloat16()(x)
  te_out, te_logits = TEMixtralSparseMoeBlock(cfg).cuda().bfloat16()(x)
  assert hf_out.shape == te_out.shape
  assert hf_logits.shape == te_logits.shape
  print('PASS')"

  Input  shape : torch.Size([2, 8, 256])
  Output shape : torch.Size([2, 8, 256])  (matches HF: True)
  Logits shape : torch.Size([16, 4])  (matches HF: True)
  Output dtype : torch.bfloat16
  PASS

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
The docs CI was failing because nbsphinx auto mode tries to execute
notebooks with no stored outputs. Add explicit execute:never metadata
so the docs builder renders the notebook as-is without running cells.

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
…otebook

te_mixtral.py:
- Replace per-expert .item() loop with torch.bincount + .tolist() (8 GPU syncs
  → 1), eliminating the main cause of GroupedLinear speedup regression
- Remove unnecessary num_out_tokens/max_token_num args from moe_permute
- Fix replace_params to use .copy_() instead of fragile .data[] assignment,
  and load_state_dict from fully-populated te_state in one shot
- Add device_map="auto" support in from_hf_model via accelerate.dispatch_model
- Rename replace_params -> _pack_expert_weights for clarity

tutorial_accelerate_hf_mixtral_with_te.ipynb:
- Remove shape-check section (redundant)
- Add FP8 prefill and decode-regime benchmark cells with summary table
- Restructure training sections to match te_llama pattern: restart at top of
  each cell, combined hyperparams + init + wrap + finetune, concise markdown
- Bump benchmark SEQ 512 -> 2048 for realistic H100 workload

utils.py:
- Apply user batch size and sequence length adjustments

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
for more information, see https://pre-commit.ci

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Comment on lines +1 to +275

P1 Notebook Tier bash commands misaligned with run_finetune_ep.py improvement indices

The bash commands shown for Tiers 2–4 reference --improvement values that map to different configurations in run_finetune_ep.py than the notebook describes. Concretely:

  • Tier 2: described as "TE GroupedLinear EP"; bash uses --improvement 2; run_finetune_ep.py actually runs expert_ffn_mode="loop" (Python F.linear, not GroupedLinear)
  • Tier 3: described as "Fused DeepEP dispatcher (BF16)"; bash uses --improvement 3; actually runs expert_ffn_mode="grouped_op", no DeepEP
  • Tier 4: described as "MXFP8 precision"; bash uses --improvement 4; actually runs DeepEP BF16 GroupedLinear, not MXFP8

Additionally, every multi-GPU bash command omits --ep-size 8, so the default --ep-size 2 (4 experts/rank) is used, whereas the Python cells in each tier all set hp.expert_parallel_size = 8 (1 expert/rank). The benchmark step-time numbers in the results tables were measured at EP=8, so users following the bash commands will reproduce different performance characteristics.

The correct mapping from notebook tier to --improvement flag appears to be: Tier 2 → --improvement 6, Tier 3 → --improvement 4, Tier 4 → --improvement 8 (all with --ep-size 8 added).

faradawn and others added 3 commits May 7, 2026 10:14
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Comment on lines +1 to +10
torch
torchao!=0.14.0
transformer_engine[pytorch]

transformers
accelerate
datasets
safetensors
huggingface_hub
tokenizers

P1 Missing flash-attn dependency

Both init_baseline_model and init_te_mixtral_model in utils.py explicitly set config._attn_implementation = "flash_attention_2". HuggingFace transformers validates this at from_pretrained time and raises ImportError if the flash-attn package is absent. Every tutorial user following the baseline or TE fine-tuning paths hits this error before the first training step. Add flash-attn to requirements.txt (the flash-attn PyPI package).
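
A one-line addition to requirements.txt would cover this (the exact version pin is left to the authors):

flash-attn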

faradawn added 2 commits May 8, 2026 07:24
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Comment on lines +236 to +237
model.load_state_dict(te_state_dict, strict=False)
del hf_model

P1 Silent weight-mapping failures in init_te_mixtral_model

load_state_dict(strict=False) swallows all missing and unexpected keys, so if replace_params fails to map any parameter (e.g., due to a checkpoint key-name mismatch), the model silently trains with random initialization. test_accuracy.py shows the correct pattern: capture the return value and assert that the only missing keys are TE-internal _extra_state entries.

Suggested change

# current
model.load_state_dict(te_state_dict, strict=False)
del hf_model

# suggested
missing, unexpected = model.load_state_dict(te_state_dict, strict=False)
if unexpected:
    raise RuntimeError(f"Unexpected keys when loading TE state dict: {unexpected}")
non_extra_missing = [k for k in missing if not k.endswith("_extra_state")]
if non_extra_missing:
    raise RuntimeError(f"Missing non-extra-state keys in TE model: {non_extra_missing}")
del hf_model

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Comment on lines +979 to +984
should_pack_inputs = not any(has_thd_input) and self.config.attn_input_format == "thd"

if should_pack_inputs:
    assert (
        attention_mask is not None
    ), "Attention mask is required when packing BSHD inputs."

P1 Decode-step crash with default attn_input_format="thd" when attention_mask is omitted

should_pack_inputs is True whenever attn_input_format == "thd" and no explicit THD kwargs are supplied, which is the normal case for every decode step during generation. In that path the assertion fires immediately because HF generate() typically does not propagate attention_mask on decode steps (sequence length 1). Adding a guard like past_key_values is None (or hidden_states.size(1) > 1) to should_pack_inputs would skip packing for single-token decode steps where the sequence is already effectively packed.
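
A sketch of one possible guard along those lines is below; the variable names mirror the quoted snippet, and whichever of past_key_values/hidden_states is actually in scope at that point in te_mixtral.py is an assumption here.

# Hypothetical guard: only pack on prefill-style calls (no KV cache yet, or more
# than one token), so single-token decode steps skip packing and the assertion.
is_prefill = past_key_values is None or hidden_states.size(1) > 1
should_pack_inputs = (
    not any(has_thd_input)
    and self.config.attn_input_format == "thd"
    and is_prefill
)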

routing_weights_for_unpermute = routing_weights
map_type = "index"

if self._ep_group is not None:

P1 Non-standard sentinel for no-token-dropping in moe_permute

The TE API documents -1 as the sentinel meaning "no token dropping"; the value used here (0) only happens to work because the fake/shape-inference path guards with num_out_tokens > 0, and 0 fails that check. If a future TE CUDA kernel instead treats 0 as a literal output-token count, permuted_hidden becomes empty and every all-to-all dispatch silently sends no tokens to any expert. Use -1 to match the documented contract.

faradawn and others added 2 commits May 10, 2026 18:20
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Comment on lines +622 to +628
if isinstance(past_key_values, InferenceParams):
    lengths = (
        attention_mask.sum(dim=1).tolist()
        if attention_mask.shape == input_ids.shape
        else [1] * input_ids.shape[0]
    )
    past_key_values.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths)))

P1 input_ids crash when inputs_embeds is supplied

When a caller passes inputs_embeds instead of input_ids, input_ids is None. Both branches of the ternary — attention_mask.shape == input_ids.shape and [1] * input_ids.shape[0] — then raise AttributeError: 'NoneType' object has no attribute 'shape'. The fix is to derive the reference tensor from whichever of input_ids/inputs_embeds is non-None, mirroring the guard already needed in the parent te_mixtral.py.
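
A hedged sketch of that guard, mirroring the quoted snippet (the exact surrounding variables in te_mixtral_mxfp8.py are assumed):

# Hypothetical fix: derive batch size and sequence length from whichever of
# input_ids / inputs_embeds the caller actually supplied.
if isinstance(past_key_values, InferenceParams):
    ref = input_ids if input_ids is not None else inputs_embeds
    batch_size, seq_len = ref.shape[0], ref.shape[1]
    lengths = (
        attention_mask.sum(dim=1).tolist()
        if attention_mask is not None and attention_mask.shape[-1] == seq_len
        else [1] * batch_size
    )
    past_key_values.pre_step(OrderedDict(zip(range(len(lengths)), lengths)))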

Comment on lines +587 to +588
if should_pack_inputs:
    assert attention_mask is not None, "attention_mask required when packing BSHD."

P1 Assertion crash during auto-regressive decode

should_pack_inputs is True whenever attn_input_format == "thd" and no THD kwargs are already present, which is the default for every decode step in model.generate(). On decode steps HuggingFace typically does not forward attention_mask, so this assertion fires immediately: AssertionError: attention_mask required when packing BSHD. A guard like past_key_values is None (or hidden_states.size(1) > 1) would skip packing for single-token steps where packing is unnecessary.

faradawn and others added 5 commits May 11, 2026 07:40
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
@faradawn faradawn (Contributor, Author) commented:

Added accuracy test between TE and HF - passed

  Max abs diff  : 0.3516
  Mean abs diff : 0.0220
  HF loss       : 7.0018
  TE loss       : 6.9995
  MXFP8 parity OK

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
@vthumbe1503 vthumbe1503 self-requested a review May 12, 2026 21:58
faradawn added 4 commits May 12, 2026 23:07
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Comment on lines +707 to +711
if isinstance(gate_up_w, DTensor) or isinstance(gate_up_w.data, DTensor):
    gate_up_w = gate_up_w.to_local()
down_w = self.experts_down_weight
if isinstance(down_w, DTensor) or isinstance(down_w.data, DTensor):
    down_w = down_w.to_local()

P1 to_local() called on nn.Parameter, not on DTensor — crashes under EP

After set_ep_group(), self.experts_gate_up_weight is nn.Parameter(DTensor.from_local(...)). The check isinstance(gate_up_w.data, DTensor) correctly identifies this case, but then calls gate_up_w.to_local() on the nn.Parameter wrapper. nn.Parameter is a separate torch.Tensor subclass and does not define to_local(), so every forward pass in loop mode with ep_size > 1 raises AttributeError: 'Parameter' object has no attribute 'to_local'.

The in-code comment correctly explains that .data must be avoided to preserve the autograd graph, but the fix cannot be simply calling to_local() on the Parameter either. One approach that preserves the graph is to re-wrap the local shard: gate_up_w = type(gate_up_w)(gate_up_w.data.to_local()). The same issue applies to down_w on lines 710-711.
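
Expressed as code, the suggested re-wrap would look roughly like this (whether it fully preserves gradient flow under EP is something the authors should verify):

# Sketch of the reviewer-suggested fix: take the local shard from the wrapped
# DTensor and re-wrap it, instead of calling to_local() on the nn.Parameter.
gate_up_w = self.experts_gate_up_weight
if isinstance(gate_up_w.data, DTensor):
    gate_up_w = type(gate_up_w)(gate_up_w.data.to_local())
down_w = self.experts_down_weight
if isinstance(down_w.data, DTensor):
    down_w = type(down_w)(down_w.data.to_local())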
